{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Video classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ビデオ分類は、ビデオ全体にラベルまたはクラスを割り当てるタスクです。ビデオには、各ビデオに 1 つのクラスのみが含まれることが期待されます。ビデオ分類モデルはビデオを入力として受け取り、ビデオがどのクラスに属するかについての予測を返します。これらのモデルを使用して、ビデオの内容を分類できます。ビデオ分類の実際のアプリケーションはアクション/アクティビティ認識であり、フィットネス アプリケーションに役立ちます。また、視覚障害のある人にとって、特に通勤時に役立ちます。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "1. [UCF101](https://www.crcv.ucf.edu/) のサブセットで [VideoMAE](https://huggingface.co/docs/transformers/main/en/model_doc/videomae) を微調整します。 data/UCF101.php) データセット。\n",
    "2. 微調整したモデルを推論に使用します。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/video-classification) を確認することをお勧めします。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。\n",
    "```bash\n",
    "pip install -q pytorchvideo transformers evaluate\n",
    "```\n",
    "\n",
    "[PyTorchVideo](https://pytorchvideo.org/) (`pytorchvideo` と呼ばれます) を使用してビデオを処理し、準備します。\n",
    "\n",
    "モデルをアップロードしてコミュニティと共有できるように、Hugging Face アカウントにログインすることをお勧めします。プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load UCF101 dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まず、[UCF-101 データセット](https://www.crcv.ucf.edu/data/UCF101.php) のサブセットをロードします。これにより、完全なデータセットのトレーニングにさらに時間を費やす前に、実験してすべてが機能することを確認する機会が得られます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "hf_dataset_identifier = \"sayakpaul/ucf101-subset\"\n",
    "filename = \"UCF101_subset.tar.gz\"\n",
    "file_path = hf_hub_download(repo_id=hf_dataset_identifier, filename=filename, repo_type=\"dataset\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "サブセットをダウンロードした後、圧縮アーカイブを抽出する必要があります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "\n",
    "with tarfile.open(file_path) as t:\n",
    "     t.extractall(\".\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "大まかに言うと、データセットは次のように構成されています。\n",
    "\n",
    "```bash\n",
    "UCF101_subset/\n",
    "    train/\n",
    "        BandMarching/\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        Archery\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        ...\n",
    "    val/\n",
    "        BandMarching/\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        Archery\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        ...\n",
    "    test/\n",
    "        BandMarching/\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        Archery\n",
    "            video_1.mp4\n",
    "            video_2.mp4\n",
    "            ...\n",
    "        ...\n",
    "```\n",
    "\n",
    "(`sorted`)された ビデオ パスは次のように表示されます。\n",
    "\n",
    "\n",
    "```bash\n",
    "...\n",
    "'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi',\n",
    "'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c06.avi',\n",
    "'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi',\n",
    "'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c02.avi',\n",
    "'UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g09_c06.avi'\n",
    "...\n",
    "```\n",
    "\n",
    "同じグループ/シーンに属するビデオ クリップがあり、ビデオ ファイル パスではグループが`g`で示されていることがわかります。たとえば、`v_ApplyEyeMakeup_g07_c04.avi`や`v_ApplyEyeMakeup_g07_c06.avi`などです。\n",
    "\n",
    "検証と評価の分割では、[データ漏洩](https://www.kaggle.com/code/alexisbcook/data-leakage) を防ぐために、同じグループ/シーンからのビデオ クリップを使用しないでください。このチュートリアルで使用しているサブセットでは、この情報が考慮されています。\n",
    "\n",
    "次に、データセット内に存在するラベルのセットを取得します。また、モデルを初期化するときに役立つ 2 つの辞書を作成します。\n",
    "\n",
    "* `label2id`: クラス名を整数にマップします。\n",
    "* `id2label`: 整数をクラス名にマッピングします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# Unique classes: ['ApplyEyeMakeup', 'ApplyLipstick', 'Archery', 'BabyCrawling', 'BalanceBeam', 'BandMarching', 'BaseballPitch', 'Basketball', 'BasketballDunk', 'BenchPress']."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class_labels = sorted({str(path).split(\"/\")[2] for path in all_video_file_paths})\n",
    "label2id = {label: i for i, label in enumerate(class_labels)}\n",
    "id2label = {i: label for label, i in label2id.items()}\n",
    "\n",
    "print(f\"Unique classes: {list(label2id.keys())}.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "個性的なクラスが10種類あります。トレーニング セットには、クラスごとに 30 個のビデオがあります。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a model to fine-tune"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "事前トレーニングされたチェックポイントとそれに関連する画像プロセッサからビデオ分類モデルをインスタンス化します。モデルのエンコーダーには事前トレーニングされたパラメーターが付属しており、分類ヘッドはランダムに初期化されます。画像プロセッサは、データセットの前処理パイプラインを作成するときに役立ちます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification\n",
    "\n",
    "model_ckpt = \"MCG-NJU/videomae-base\"\n",
    "image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)\n",
    "model = VideoMAEForVideoClassification.from_pretrained(\n",
    "    model_ckpt,\n",
    "    label2id=label2id,\n",
    "    id2label=id2label,\n",
    "    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルのロード中に、次の警告が表示される場合があります。\n",
    "\n",
    "```bash\n",
    "Some weights of the model checkpoint at MCG-NJU/videomae-base were not used when initializing VideoMAEForVideoClassification: [..., 'decoder.decoder_layers.1.attention.output.dense.bias', 'decoder.decoder_layers.2.attention.attention.key.weight']\n",
    "- This IS expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
    "- This IS NOT expected if you are initializing VideoMAEForVideoClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
    "Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at MCG-NJU/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
    "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
    "```\n",
    "\n",
    "この警告は、一部の重み (たとえば、`classifier`層の重みとバイアス) を破棄し、他のいくつかの重み (新しい`classifier`層の重みとバイアス) をランダムに初期化していることを示しています。この場合、これは予想されることです。事前にトレーニングされた重みを持たない新しい頭部を追加しているため、推論に使用する前にこのモデルを微調整する必要があるとライブラリが警告します。これはまさに私たちが行おうとしているものです。する。\n",
    "\n",
    "**注意** [このチェックポイント](https://huggingface.co/MCG-NJU/videomae-base-finetuned-kinetics) は、同様のダウンストリームで微調整されてチェックポイントが取得されたため、このタスクのパフォーマンスが向上することに注意してください。かなりのドメインの重複があるタスク。 `MCG-NJU/videomae-base-finetuned-kinetics` を微調整して取得した [このチェックポイント](https://huggingface.co/sayakpaul/videomae-base-finetuned-kinetics-finetuned-ucf101-subset) を確認できます。 -キネティクス`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare the datasets for training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ビデオの前処理には、[PyTorchVideo ライブラリ](https://pytorchvideo.org/) を利用します。まず、必要な依存関係をインポートします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorchvideo.data\n",
    "\n",
    "from pytorchvideo.transforms import (\n",
    "    ApplyTransformToKey,\n",
    "    Normalize,\n",
    "    RandomShortSideScale,\n",
    "    RemoveKey,\n",
    "    ShortSideScale,\n",
    "    UniformTemporalSubsample,\n",
    ")\n",
    "\n",
    "from torchvision.transforms import (\n",
    "    Compose,\n",
    "    Lambda,\n",
    "    RandomCrop,\n",
    "    RandomHorizontalFlip,\n",
    "    Resize,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニング データセットの変換には、均一な時間サブサンプリング、ピクセル正規化、ランダム クロッピング、およびランダムな水平反転を組み合わせて使用​​します。検証および評価のデータセット変換では、ランダムなトリミングと水平反転を除き、同じ変換チェーンを維持します。これらの変換の詳細については、[PyTorchVideo の公式ドキュメント](https://pytorchvideo.org) を参照してください。\n",
    "\n",
    "事前トレーニングされたモデルに関連付けられた`image_processor`を使用して、次の情報を取得します。\n",
    "\n",
    "* ビデオ フレームのピクセルが正規化される画像の平均値と標準偏差。\n",
    "* ビデオ フレームのサイズが変更される空間解像度。\n",
    "\n",
    "まず、いくつかの定数を定義します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean = image_processor.image_mean\n",
    "std = image_processor.image_std\n",
    "if \"shortest_edge\" in image_processor.size:\n",
    "    height = width = image_processor.size[\"shortest_edge\"]\n",
    "else:\n",
    "    height = image_processor.size[\"height\"]\n",
    "    width = image_processor.size[\"width\"]\n",
    "resize_to = (height, width)\n",
    "\n",
    "num_frames_to_sample = model.config.num_frames\n",
    "sample_rate = 4\n",
    "fps = 30\n",
    "clip_duration = num_frames_to_sample * sample_rate / fps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、データセット固有の変換とデータセットをそれぞれ定義します。トレーニングセットから始めます:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_transform = Compose(\n",
    "    [\n",
    "        ApplyTransformToKey(\n",
    "            key=\"video\",\n",
    "            transform=Compose(\n",
    "                [\n",
    "                    UniformTemporalSubsample(num_frames_to_sample),\n",
    "                    Lambda(lambda x: x / 255.0),\n",
    "                    Normalize(mean, std),\n",
    "                    RandomShortSideScale(min_size=256, max_size=320),\n",
    "                    RandomCrop(resize_to),\n",
    "                    RandomHorizontalFlip(p=0.5),\n",
    "                ]\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "train_dataset = pytorchvideo.data.Ucf101(\n",
    "    data_path=os.path.join(dataset_root_path, \"train\"),\n",
    "    clip_sampler=pytorchvideo.data.make_clip_sampler(\"random\", clip_duration),\n",
    "    decode_audio=False,\n",
    "    transform=train_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同じ一連のワークフローを検証セットと評価セットに適用できます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_transform = Compose(\n",
    "    [\n",
    "        ApplyTransformToKey(\n",
    "            key=\"video\",\n",
    "            transform=Compose(\n",
    "                [\n",
    "                    UniformTemporalSubsample(num_frames_to_sample),\n",
    "                    Lambda(lambda x: x / 255.0),\n",
    "                    Normalize(mean, std),\n",
    "                    Resize(resize_to),\n",
    "                ]\n",
    "            ),\n",
    "        ),\n",
    "    ]\n",
    ")\n",
    "\n",
    "val_dataset = pytorchvideo.data.Ucf101(\n",
    "    data_path=os.path.join(dataset_root_path, \"val\"),\n",
    "    clip_sampler=pytorchvideo.data.make_clip_sampler(\"uniform\", clip_duration),\n",
    "    decode_audio=False,\n",
    "    transform=val_transform,\n",
    ")\n",
    "\n",
    "test_dataset = pytorchvideo.data.Ucf101(\n",
    "    data_path=os.path.join(dataset_root_path, \"test\"),\n",
    "    clip_sampler=pytorchvideo.data.make_clip_sampler(\"uniform\", clip_duration),\n",
    "    decode_audio=False,\n",
    "    transform=val_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**注意**: 上記のデータセット パイプラインは、[公式 PyTorchVideo サンプル](https://pytorchvideo.org/docs/tutorial_classification#dataset) から取得したものです。 [`pytorchvideo.data.Ucf101()`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.Ucf101) 関数を使用しています。 UCF-101 データセット。内部では、[`pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset`](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html#pytorchvideo.data.LabeledVideoDataset) オブジェクトを返します。 `LabeledVideoDataset` クラスは、PyTorchVideo データセット内のすべてのビデオの基本クラスです。したがって、PyTorchVideo で既製でサポートされていないカスタム データセットを使用したい場合は、それに応じて `LabeledVideoDataset` クラスを拡張できます。詳細については、`data`API [ドキュメント](https://pytorchvideo.readthedocs.io/en/latest/api/data/data.html)を参照してください。また、データセットが同様の構造 (上に示したもの) に従っている場合は、`pytorchvideo.data.Ucf101()` を使用すると問題なく動作するはずです。\n",
    "\n",
    "`num_videos` 引数にアクセスすると、データセット内のビデオの数を知ることができます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# (300, 30, 75)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize the preprocessed video for better debugging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "...\n",
       "    The video tensor is expected to have the following shape:\n",
       "    (num_frames, num_channels, height, width).\n",
       "    \"\"\"\n",
       "    frames = []\n",
       "    for video_frame in video_tensor:\n",
       "        frame_unnormalized = unnormalize_img(video_frame.permute(1, 2, 0).numpy())\n",
       "        frames.append(frame_unnormalized)\n",
       "    kargs = {\"duration\": 0.25}\n",
       "    imageio.mimsave(filename, frames, \"GIF\", **kargs)\n",
       "    return filename"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import imageio\n",
    "import numpy as np\n",
    "from IPython.display import Image\n",
    "\n",
    "def unnormalize_img(img):\n",
    "    \"\"\"Un-normalizes the image pixels.\"\"\"\n",
    "    img = (img * std) + mean\n",
    "    img = (img * 255).astype(\"uint8\")\n",
    "    return img.clip(0, 255)\n",
    "\n",
    "def create_gif(video_tensor, filename=\"sample.gif\"):\n",
    "    \"\"\"Prepares a GIF from a video tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_gif(video_tensor, gif_name=\"sample.gif\"):\n",
    "    \"\"\"Prepares and displays a GIF from a video tensor.\"\"\"\n",
    "    video_tensor = video_tensor.permute(1, 0, 2, 3)\n",
    "    gif_filename = create_gif(video_tensor, gif_name)\n",
    "    return Image(filename=gif_filename)\n",
    "\n",
    "sample_video = next(iter(train_dataset))\n",
    "video_tensor = sample_video[\"video\"]\n",
    "display_gif(video_tensor)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif.gif\" alt=\"Person playing basketball\"/>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🤗 Transformers の [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer) をモデルのトレーニングに利用します。 `Trainer`をインスタンス化するには、トレーニング構成と評価メトリクスを定義する必要があります。最も重要なのは [`TrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.TrainingArguments) で、これはトレーニングを構成するためのすべての属性を含むクラスです。モデルのチェックポイントを保存するために使用される出力フォルダー名が必要です。また、🤗 Hub 上のモデル リポジトリ内のすべての情報を同期するのにも役立ちます。\n",
    "\n",
    "トレーニング引数のほとんどは一目瞭然ですが、ここで非常に重要なのは`remove_unused_columns=False`です。これにより、モデルの呼び出し関数で使用されない機能が削除されます。デフォルトでは`True`です。これは、通常、未使用の特徴列を削除し、モデルの呼び出し関数への入力を解凍しやすくすることが理想的であるためです。ただし、この場合、`pixel_values` (モデルが入力で期待する必須キーです) を作成するには、未使用の機能 (特に`video`) が必要です。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "model_name = model_ckpt.split(\"/\")[-1]\n",
    "new_model_name = f\"{model_name}-finetuned-ucf101-subset\"\n",
    "num_epochs = 4\n",
    "\n",
    "args = TrainingArguments(\n",
    "    new_model_name,\n",
    "    remove_unused_columns=False,\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=5e-5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    warmup_steps=0.1,\n",
    "    logging_steps=10,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    push_to_hub=True,\n",
    "    max_steps=(train_dataset.num_videos // batch_size) * num_epochs,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`pytorchvideo.data.Ucf101()` によって返されるデータセットは `__len__` メソッドを実装していません。そのため、`TrainingArguments`をインスタンス化するときに`max_steps`を定義する必要があります。\n",
    "\n",
    "次に、予測からメトリクスを計算する関数を定義する必要があります。これは、これからロードする`metric`を使用します。必要な前処理は、予測されたロジットの argmax を取得することだけです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions = np.argmax(eval_pred.predictions, axis=1)\n",
    "    return metric.compute(predictions=predictions, references=eval_pred.label_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**評価に関する注意事項**:\n",
    "\n",
    "[VideoMAE 論文](https://huggingface.co/papers/2203.12602) では、著者は次の評価戦略を使用しています。彼らはテスト ビデオからのいくつかのクリップでモデルを評価し、それらのクリップにさまざまなクロップを適用して、合計スコアを報告します。ただし、単純さと簡潔さを保つために、このチュートリアルではそれを考慮しません。\n",
    "\n",
    "また、サンプルをまとめてバッチ処理するために使用される `collat​​e_fn` を定義します。各バッチは、`pixel_values` と `labels` という 2 つのキーで構成されます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(examples):\n",
    "    # permute to (num_frames, num_channels, height, width)\n",
    "    pixel_values = torch.stack(\n",
    "        [example[\"video\"].permute(1, 0, 2, 3) for example in examples]\n",
    "    )\n",
    "    labels = torch.tensor([example[\"label\"] for example in examples])\n",
    "    return {\"pixel_values\": pixel_values, \"labels\": labels}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、これらすべてをデータセットとともに`Trainer`に渡すだけです。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    processing_class=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    "    data_collator=collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "すでにデータを前処理しているのに、なぜトークナイザーとして`image_processor`を渡したのか不思議に思うかもしれません。これは、イメージ プロセッサ構成ファイル (JSON として保存) もハブ上のリポジトリにアップロードされるようにするためだけです。\n",
    "\n",
    "次に、`train` メソッドを呼び出してモデルを微調整します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_results = trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "トレーニングが完了したら、 [push_to_hub()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.push_to_hub) メソッドを使用してモデルをハブに共有し、誰もがモデルを使用できるようにします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "モデルを微調整したので、それを推論に使用できるようになりました。\n",
    "\n",
    "推論のためにビデオをロードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_test_video = next(iter(test_dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "    <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/sample_gif_two.gif\" alt=\"Teams playing basketball\"/>\n",
    "</div>\n",
    "\n",
    "推論用に微調整されたモデルを試す最も簡単な方法は、それを [`pipeline`](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.VideoClassificationPipeline). で使用することです。モデルを使用してビデオ分類用の` pipeline`をインスタンス化し、それにビデオを渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9272987842559814, 'label': 'BasketballDunk'},\n",
       " {'score': 0.017777055501937866, 'label': 'BabyCrawling'},\n",
       " {'score': 0.01663011871278286, 'label': 'BalanceBeam'},\n",
       " {'score': 0.009560945443809032, 'label': 'BandMarching'},\n",
       " {'score': 0.0068979403004050255, 'label': 'BaseballPitch'}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "video_cls = pipeline(model=\"my_awesome_video_cls_model\")\n",
    "video_cls(\"https://huggingface.co/datasets/sayakpaul/ucf101-subset/resolve/main/v_BasketballDunk_g14_c06.avi\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "必要に応じて、`pipeline`の結果を手動で複製することもできます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_inference(model, video):\n",
    "    # (num_frames, num_channels, height, width)\n",
    "    perumuted_sample_test_video = video.permute(1, 0, 2, 3)\n",
    "    inputs = {\n",
    "        \"pixel_values\": perumuted_sample_test_video.unsqueeze(0),\n",
    "        \"labels\": torch.tensor(\n",
    "            [sample_test_video[\"label\"]]\n",
    "        ),  # this can be skipped if you don't have labels available.\n",
    "    }\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    model = model.to(device)\n",
    "\n",
    "    # forward pass\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**inputs)\n",
    "        logits = outputs.logits\n",
    "\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、入力をモデルに渡し、`logits `を返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = run_inference(trained_model, sample_test_video[\"video\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`logits` をデコードすると、次のようになります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# Predicted class: BasketballDunk"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predicted_class_idx = logits.argmax(-1).item()\n",
    "print(\"Predicted class:\", model.config.id2label[predicted_class_idx])"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
